#!/bin/bash

# This file is documentation for how to get started with gpt-oss-120b on v5p-64.

# The flow of this file is as follows:
# 1. Convert the HuggingFace checkpoint (bf16) to a MaxText-compatible checkpoint (bf16), on a separate CPU.
#    The scanned format is better for training; the unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

# Example Usage: export HF_TOKEN=<huggingface_access_token>; export BASE_OUTPUT_PATH=<GCS_bucket_path>; bash test_gpt_oss.sh

# The golden logits can be generated with:
# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=openai/gpt-oss-120b --output-path=golden_data_gpt-oss-120b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=<local_bf16_path>

# Stop at the first failing command (errexit) and echo every command before it
# runs (xtrace).
set -o errexit -o xtrace

# Model name and tokenizer path; both are passed to the MaxText commands below.
export MODEL_NAME='gpt-oss-120b' TOKENIZER_PATH='openai/gpt-oss-120b'

# Keep a caller-provided BASE_OUTPUT_PATH; otherwise fall back to a
# minute-resolution timestamped directory and report the chosen value.
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
[ -n "${BASE_OUTPUT_PATH}" ] || {
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
  echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
}

# forward_pass_logit_checker.py needs torch; install it from the PyTorch
# `whl/cpu` package index.
python3 -m pip install \
  --index-url https://download.pytorch.org/whl/cpu \
  torch

# Step 1: Checkpoint conversion
# The HF checkpoint is assumed to be uploaded to the GCS bucket at `CKPT_BUCKET`.
# Non-Googlers: export `CKPT_BUCKET` to point at a GCS bucket that you own; the
# internal default below is used only when `CKPT_BUCKET` is unset or empty.
# The checkpoint is copied into `/tmp`. To use a different directory (or a
# checkpoint that is already on local disk), export `CKPT_DISK_LOCATION`, which
# skips the download entirely.
if [ -z "${CKPT_DISK_LOCATION}" ]; then
  CKPT_BUCKET="${CKPT_BUCKET:-gs://maxtext-model-checkpoints/gpt-oss-120b/hf-bf16}"
  # Drop one trailing slash so the last path component can be derived below.
  export CKPT_BUCKET="${CKPT_BUCKET%/}"
  gcloud storage cp -r "${CKPT_BUCKET}" /tmp
  # The recursive copy is expected to land in /tmp/<last component of CKPT_BUCKET>
  # (for the default bucket this is /tmp/hf-bf16, the previously hard-coded value).
  export CKPT_DISK_LOCATION="/tmp/${CKPT_BUCKET##*/}"
fi

# 1.1 Convert checkpoint to `scanned` format, more suitable for training.
# JAX_PLATFORMS=cpu is set for the conversion command only.
# Path expansions are quoted so values containing whitespace or glob characters
# are passed through as a single argument.
JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_gpt_oss_ckpt \
  --base-model-path "${CKPT_DISK_LOCATION}" \
  --maxtext-model-path "${BASE_OUTPUT_PATH}/scanned" \
  --model-size "${MODEL_NAME}"

# 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding.
JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_gpt_oss_unscanned_ckpt \
  --base-model-path "${CKPT_DISK_LOCATION}" \
  --maxtext-model-path "${BASE_OUTPUT_PATH}/unscanned" \
  --model-size "${MODEL_NAME}"

# Step 2:
# Define the checkpoint paths once so they are easier to reuse in the
# `train.py` and `decode.py` commands.
export SCANNED_CKPT_PATH="${BASE_OUTPUT_PATH}/scanned/0/items"
export UNSCANNED_CKPT_PATH="${BASE_OUTPUT_PATH}/unscanned/0/items"
# Non-Googlers: export `DATASET_PATH` to point at the GCS bucket that holds your
# training data. The default below is used only when `DATASET_PATH` is unset or
# empty, so a caller-provided value is no longer overwritten.
export DATASET_PATH="${DATASET_PATH:-gs://maxtext-dataset}"

# Test whether the forward pass logits match the golden logits
# default golden_logits_path=/deps/src/MaxText/test_assets/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
# Loads the unscanned checkpoint (scan_layers=false). Variable expansions are
# quoted so each key=value pair stays a single argument.
python3 -m tests.forward_pass_logit_checker \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
  base_output_directory="${BASE_OUTPUT_PATH}" \
  run_name=forward_logits_check \
  model_name="${MODEL_NAME}" \
  load_parameters_path="${UNSCANNED_CKPT_PATH}" \
  scan_layers=false \
  attention=dot_product \
  sparse_matmul=True \
  megablox=True \
  per_device_batch_size=1 \
  max_target_length=4 \
  max_prefill_predict_length=4 \
  dtype=float32 \
  --atol=0.5 \
  --rtol=0.5 \
  --max_kl_div=3e-3

# Run pre-training - megablox implementation
# Uses a synthetic dataset with checkpointing disabled and runs for 5 steps.
# Variable expansions are quoted so each key=value pair stays a single argument.
python3 -m MaxText.train \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
  base_output_directory="${BASE_OUTPUT_PATH}" \
  run_name=megablox_pre_training \
  model_name="${MODEL_NAME}" \
  tokenizer_type=huggingface \
  tokenizer_path="${TOKENIZER_PATH}" \
  dataset_type=synthetic \
  enable_checkpointing=false \
  attention=flash \
  sparse_matmul=True \
  megablox=True \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  per_device_batch_size=4 \
  steps=5 \
  max_target_length=1024 \
  ici_fsdp_parallelism=32

# Run fine-tuning - megablox implementation
# Starts from the scanned checkpoint (scan_layers=True), reads training data
# from DATASET_PATH, and runs for 5 steps.
# Variable expansions are quoted so each key=value pair stays a single argument.
python3 -m MaxText.train \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
  base_output_directory="${BASE_OUTPUT_PATH}" \
  run_name=megablox_fine_tuning \
  model_name="${MODEL_NAME}" \
  tokenizer_type=huggingface \
  tokenizer_path="${TOKENIZER_PATH}" \
  dataset_path="${DATASET_PATH}" \
  enable_checkpointing=true \
  async_checkpointing=false \
  load_parameters_path="${SCANNED_CKPT_PATH}" \
  scan_layers=True \
  attention=flash \
  sparse_matmul=True \
  megablox=True \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  per_device_batch_size=4 \
  steps=5 \
  max_target_length=1024 \
  ici_fsdp_parallelism=1 \
  ici_expert_parallelism=32

# Run supervised fine-tuning - megablox implementation
# Uses the sft.yml config with dataset_type=hf, starts from the scanned
# checkpoint (scan_layers=True), and runs for 5 steps.
# Variable expansions are quoted so each key=value pair stays a single argument.
python3 -m MaxText.sft_trainer \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/sft.yml \
  base_output_directory="${BASE_OUTPUT_PATH}" \
  run_name=megablox_supervised_fine_tuning \
  model_name="${MODEL_NAME}" \
  tokenizer_type=huggingface \
  tokenizer_path="${TOKENIZER_PATH}" \
  dataset_type=hf \
  enable_checkpointing=true \
  async_checkpointing=false \
  load_parameters_path="${SCANNED_CKPT_PATH}" \
  scan_layers=True \
  attention=flash \
  sparse_matmul=True \
  megablox=True \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  per_device_batch_size=4 \
  steps=5 \
  max_target_length=1024 \
  ici_fsdp_parallelism=1 \
  ici_expert_parallelism=32

# Run decoding - megablox implementation
# Note decode requires the access token for huggingface tokenizer even if the model is not gated
# With xtrace enabled (`set -x`), bash prints the fully expanded command line,
# which would write HF_TOKEN to the logs. Turn xtrace off for this command only
# and restore it afterwards if it was on.
case "$-" in
  *x*)
    set +x
    restore_xtrace=1
    ;;
  *)
    restore_xtrace=0
    ;;
esac
echo "Running MaxText.decode (xtrace suppressed so HF_TOKEN is not logged)" >&2
python3 -m MaxText.decode \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
  base_output_directory="${BASE_OUTPUT_PATH}" \
  run_name=decode \
  model_name="${MODEL_NAME}" \
  tokenizer_type=huggingface \
  tokenizer_path="${TOKENIZER_PATH}" \
  hf_access_token="${HF_TOKEN}" \
  load_parameters_path="${UNSCANNED_CKPT_PATH}" \
  scan_layers=False \
  attention=dot_product \
  sparse_matmul=True \
  megablox=True \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  per_device_batch_size=1 \
  max_prefill_predict_length=64 \
  max_target_length=128 \
  prompt="I love to" \
  ici_fsdp_parallelism=32 \
  ici_tensor_parallelism=1
# Use `if` rather than `&&` so a false test does not become the exit status.
if [ "${restore_xtrace}" -eq 1 ]; then
  set -x
fi
unset restore_xtrace
